package com.innovation.ic.im.end.web.config;

import com.alibaba.fastjson.JSONObject;
import com.google.common.base.Strings;
import com.innovation.ic.b1b.framework.util.StringUtils;
import com.innovation.ic.im.end.base.pojo.constant.Constants;
import com.innovation.ic.im.end.base.pojo.constant.UserLoginType;
import com.innovation.ic.im.end.base.value.config.FilterParamConfig;
import org.apache.catalina.session.StandardSessionFacade;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Configuration;
import javax.annotation.PostConstruct;
import javax.annotation.Resource;
import javax.servlet.http.HttpSession;
import javax.websocket.HandshakeResponse;
import javax.websocket.server.HandshakeRequest;
import javax.websocket.server.ServerEndpointConfig;
import java.util.Enumeration;
import java.util.List;
import java.util.Map;

/**
 * WebSocket handshake configurator that makes the HTTP session available to the
 * WebSocket endpoint.
 *
 * <p>During the handshake it copies the {@link HttpSession} (and its attributes)
 * into the endpoint's user properties, extracts the client id from the configured
 * authentication header, and echoes the {@code Sec-WebSocket-Protocol} request
 * header back to the client.
 */
@Configuration
public class GetHttpSessionConfigurator extends ServerEndpointConfig.Configurator {
    private static final Logger log = LoggerFactory.getLogger(GetHttpSessionConfigurator.class);

    @Resource
    private FilterParamConfig filterParamConfig;

    // Static copy of the injected configuration, populated in init().
    // NOTE(review): presumably static so that configurator instances not created
    // by Spring can still read the configuration - confirm.
    private static FilterParamConfig config;

    /** Publishes the injected {@link FilterParamConfig} to the static field. */
    @PostConstruct
    public void init() {
        config = filterParamConfig;
    }

    /**
     * Modifies the handshake before the WebSocket connection is established.
     *
     * <p>Failures are logged and never propagated, so a problem while reading the
     * session or headers does not surface as an exception to the caller.
     *
     * @param sec      endpoint configuration whose user properties receive the session data
     * @param request  the opening handshake request
     * @param response the opening handshake response
     */
    @Override
    public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
        try {
            Map<String, Object> attributes = sec.getUserProperties();

            // Without a listener that creates the session, getHttpSession() returns null.
            // Check against the HttpSession interface rather than a container-specific
            // class so that the code does not fail with a ClassCastException.
            Object rawSession = request.getHttpSession();
            if (rawSession instanceof HttpSession) {
                HttpSession session = (HttpSession) rawSession;
                attributes.put(HttpSession.class.getName(), session);
                log.info("获取到的SessionID：{}", session.getId());
                attributes.put(UserLoginType.IP_ADDR, session.getAttribute("ip"));

                // Read the verification parameters from the WebSocket header and
                // store the client id in the session when one is present.
                String clientId = getClientId(readAuthHeader(request));
                if (!Strings.isNullOrEmpty(clientId)) {
                    session.setAttribute(Constants.CLIENT_ID, clientId);
                }

                // Expose every session attribute to the endpoint.
                Enumeration<String> names = session.getAttributeNames();
                while (names.hasMoreElements()) {
                    String name = names.nextElement();
                    attributes.put(name, session.getAttribute(name));
                }
            }

            // When the client sent Sec-WebSocket-Protocol, the same value has to be
            // returned in the response.
            List<String> protocols = request.getHeaders().get(Constants.SEC_WEBSOCKET_PROTOCOL);
            if (protocols != null && !protocols.isEmpty()) {
                response.getHeaders().put(Constants.SEC_WEBSOCKET_PROTOCOL, protocols);
            }

            super.modifyHandshake(sec, request, response);
        } catch (Exception e) {
            log.error("Failed to modify WebSocket handshake", e);
        }
    }

    /**
     * Reads the first value of the configured authentication header.
     *
     * @param request the opening handshake request
     * @return the first header value, or {@code null} when the configuration is not
     *         available or the header is absent
     */
    private String readAuthHeader(HandshakeRequest request) {
        FilterParamConfig cfg = config;
        if (cfg == null || cfg.getWebsocketHeaderName() == null) {
            log.warn("WebSocket header name is not configured; skipping client id extraction");
            return null;
        }

        List<String> values = request.getHeaders().get(cfg.getWebsocketHeaderName().toLowerCase());
        if (values == null || values.isEmpty()) {
            return null;
        }
        return values.get(0);
    }

    /**
     * Extracts the client id from the header string.
     *
     * <p>The header is treated as {@code key=value} pairs joined by {@code &};
     * single quotes are treated as {@code =}. Pairs without a value are ignored,
     * and when the key occurs more than once the last occurrence wins.
     *
     * @param headerString the raw header value; may be {@code null} or empty
     * @return the value stored under the configured authorization key, or
     *         {@code null} when it cannot be determined
     */
    private String getClientId(String headerString) {
        FilterParamConfig cfg = config;
        if (cfg == null || StringUtils.isEmpty(headerString)) {
            return null;
        }

        String authKey = cfg.getAuthorization();
        if (authKey == null) {
            return null;
        }

        // The front end sends single quotes in place of equals signs; normalise
        // them with a literal (non-regex) replacement before splitting.
        String normalized = headerString.replace("'", "=");

        String clientId = null;
        for (String param : normalized.split("&")) {
            String[] split = param.split("=");
            if (split.length < 2) {
                // No value part (e.g. "key", "" or "="): nothing to read.
                continue;
            }
            if (authKey.equals(split[0])) {
                clientId = split[1];
            }
        }
        return clientId;
    }
}